﻿#pragma kernel UpdateChainMesh

#include "PathFrame.cginc"

// Per-renderer settings consumed by the UpdateChainMesh kernel.
// Read through a StructuredBuffer, so the field order/types presumably must match
// the layout uploaded by the CPU side (not visible in this file).
struct chainRendererData
{
    // Cumulative modifier offset: renderer r uses the entries of `modifiers` in the range
    // [rendererData[r - 1].modifierOffset (0 for the first renderer), modifierOffset).
    int modifierOffset;
    // Scales the initial twist of the frame: -twist * elementCount * twistAnchor.
    float twistAnchor;
    // Twist handed to the frame transport for each element.
    float twist;
    // When equal to 1, particle orientations are used while transporting the frame.
    uint usesOrientedParticles;

    // xyz is multiplied into each instance's scale; w is not read by UpdateChainMesh.
    float4 scale;
};

// Describes one chunk: a contiguous run of elements processed by a single thread of UpdateChainMesh.
struct chunkData
{
    // Index into `rendererData` of the renderer this chunk belongs to.
    int rendererIndex;
    // Cumulative element offset. UpdateChainMesh treats chunk i as the element range
    // [chunksData[i - 1].offset (0 for the first chunk), chunksData[i].offset),
    // i.e. this value is one past the chunk's last element.
    int offset;
};

// Per-link adjustment applied on top of the transform computed for each element.
struct linkModifier
{
    float3 translation; // offset, rotated by the link's rotation before being added to its position.
    float3 scale;       // per-axis multiplier applied to the link's scale.
    float3 rotation;    // euler angles in degrees (converted with radians() by the kernel).

    // Resets the modifier to neutral values: zero translation, zero rotation, unit scale.
    void Clear()
    {
        translation = (float3)0;
        rotation    = (float3)0;
        scale       = (float3)1;
    }
};

// Inputs describing renderers, chunks and their elements.
StructuredBuffer<chainRendererData> rendererData; // one entry per renderer, indexed by chunkData.rendererIndex.
StructuredBuffer<chunkData> chunksData;           // one entry per chunk, indexed by dispatch thread id.
StructuredBuffer<linkModifier> modifiers;         // modifiers of all renderers; each renderer cycles through its own range.
StructuredBuffer<int2> elements;                  // per element: x/y are the indices of its two particles.

// Per-particle data, indexed with the particle indices stored in `elements`.
StructuredBuffer<float4> renderablePositions;
StructuredBuffer<quaternion> renderableOrientations;
StructuredBuffer<float4> principalRadii;          // only the x component is read by UpdateChainMesh.
StructuredBuffer<float4> colors;
float4x4 solverToWorld;                           // left-multiplied onto each instance's TRS matrix.

// Outputs, one entry per element (same index as `elements`).
RWStructuredBuffer<float4x4> instanceTransforms;
RWStructuredBuffer<float4x4> invInstanceTransforms;
RWStructuredBuffer<float4> instanceColors;

// Number of valid entries in `chunksData`; threads with an id at or past it exit immediately.
uint chunkCount;

// Writes one instance transform, its inverse and one color for every element of every chunk.
// Each thread handles a whole chunk and walks its elements in order, since the frame used to
// orient each link is transported from the previous element.
[numthreads(32, 1, 1)]
void UpdateChainMesh (uint3 id : SV_DispatchThreadID)
{
    unsigned int i = id.x;
    if (i >= chunkCount) return;

    // chunk offsets are cumulative: this chunk's elements start where the previous chunk ended.
    int firstIndex = i > 0 ? chunksData[i - 1].offset : 0;
    uint elementCount = chunksData[i].offset - firstIndex;

    int rendererIndex = chunksData[i].rendererIndex;
    chainRendererData renderer = rendererData[rendererIndex];

    // modifier offsets are cumulative too: this renderer's modifiers start where the previous renderer's ended.
    int firstModifier = rendererIndex > 0 ? rendererData[rendererIndex - 1].modifierOffset : 0;
    uint modifierCount = renderer.modifierOffset - firstModifier;

    // neutral modifier, used as-is when the renderer has no modifiers.
    linkModifier modifier;
    modifier.Clear();

    pathFrame frame;
    frame.Reset();

    // initial twist of the frame, scaled by the twist anchor.
    float twist = -renderer.twist * elementCount * renderer.twistAnchor;
    frame.SetTwist(twist);

    // parallel transport:
    for (uint m = 0; m < elementCount; ++m)
    {
        // modifiers repeat cyclically along the chunk.
        if (modifierCount > 0)
            modifier = modifiers[firstModifier + m % modifierCount];

        int index = firstIndex + m;
        int2 element = elements[index]; // indices of the two particles at the ends of this element.

        float4 pos     = renderablePositions[element.x];
        float4 nextPos = renderablePositions[element.y];
        float4 vect = nextPos - pos;
        float3 tangent = normalizesafe(vect.xyz);

        if (renderer.usesOrientedParticles == 1)
        {
            // use the first particle's orientation (its rotated Y axis) together with the accumulated twist.
            frame.Transport(nextPos.xyz, tangent, rotate_vector(renderableOrientations[element.x], float3(0, 1, 0)), twist);
            twist += renderer.twist;
        }
        else
            frame.Transport(nextPos.xyz, tangent, renderer.twist);

        // place the link at the midpoint of the element, oriented along the transported frame.
        quaternion rotation = q_look_at(frame.tangent, frame.normal);
        float3 position = (pos + vect * 0.5f).xyz + rotate_vector(rotation, modifier.translation);
        float3 scale = principalRadii[element.x].x * 2 * renderer.scale.xyz * modifier.scale;

        // modifier rotation is given as euler angles in degrees.
        rotation = qmul(rotation, q_eulerXYZ(radians(modifier.rotation)));

        float4x4 transform = mul(solverToWorld, TRS(position, rotation, scale));
        instanceTransforms[index] = transform;
        invInstanceTransforms[index] = Inverse(transform);

        // link color is the average of the colors of both particles of the element.
        instanceColors[index] = (colors[element.x] + colors[element.y]) * 0.5;
    }
}